import os
import json
import wandb
import torch
from torchvision import datasets
from torchvision import transforms
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, EarlyStopping
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import WandbLogger
from datamodules import get_datamodule
from models import LitResnet, Litreconstruction, LitResnet_SVM
from pytorch_lightning.callbacks import ModelCheckpoint

from utils import read_yaml
from settings import DSV_dir

import argparse
 
# Command-line interface for the training script.
parser = argparse.ArgumentParser(description='Configuration for training.')

# One row per option: (flag, type, default, help text).
# Rows are registered in this exact order so the `--help` listing is unchanged.
# NOTE(review): several help texts look copy-pasted from neighbouring options
# (e.g. '--patience' says 'batch size for SVM', '--policy' advertises a default
# that differs from the real one). They are reproduced verbatim here; please
# confirm the intended wording with the author before editing them.
_CLI_OPTIONS = (
    # --- global arguments ---
    ('--yaml', str, 'config.yaml', 'name of yaml file'),

    # --- optimization arguments ---
    ('--lr_image', float, 100, 'Learning rate for image optimization.'),
    ('--lr_lambda', float, 0.0001, 'Learning rate for lambda optimization.'),
    ('--momentum_image', float, 0, 'momentum for image optimization.'),
    ('--momentum_lambda', float, 0, 'momentum for lambda optimization.'),
    ('--loss_weight', float, 0.0005, 'lr for overall optimization.'),
    ('--stationarity_weight', float, 0.001, 'lr for overall optimization.'),
    ('--policy', str, 'cutout, crop, translation, flip, noise', 'policy to apply (default: flip,crop)'),
    ('--noise_ratio', float, 0.0, 'Rate for verifying images'),
    ('--negative_ratio', float, 0.01, 'Rate for verifying images'),
    ('--translation_ratio', float, 0.25, 'if 1, only use target labels, if 0, only use knowledge distillation labels'),
    ('--cutout_ratio', float, 0.15, 'if 1, only use target labels, if 0, only use knowledge distillation labels'),
    ('--temperature', float, 5, 'Rate for cross-entropy loss.'),
    ('--temperature1', float, 1, 'Rate for cross-entropy loss.'),
    ('--x_normal_std', float, 0.02, 'Learning rate for image optimization.'),
    ('--l_normal_std', float, 0.05, 'Learning rate for image optimization.'),
    ('--UPSCALE_cycle', int, 300, 'upscale once per this cycle'),
    ('--crop_size', int, 4, 'upscale once per this cycle'),
    ('--angle', int, 1, 'upscale once per this cycle'),
    ('--num_samples', int, 1, 'upscale once per this cycle'),
    ('--par_mult', float, 1, 'upscale once per this cycle'),
    ('--stationarity_rate', float, 0.1, 'Rate for cross-entropy loss.'),
    ('--primal_rate', float, 1, 'upscale once per this cycle'),
    ('--aug_stationarity_rate', float, 0.5, 'Rate for cross-entropy loss.'),
    ('--aug_primal_rate', float, 0.01, 'upscale once per this cycle'),
    ('--weight_decay_x', float, 0, 'upscale once per this cycle'),
    ('--weight_decay_l', float, 0, 'upscale once per this cycle'),

    # --- additional loss terms ---
    ('--tv_scale', float, 1, 'upscale once per this cycle'),
    ('--alpha_scale', float, 1e-5, 'upscale once per this cycle'),

    # --- SVM training ---
    ('--aug_ratio', float, 0, 'if 1, only use target labels, if 0, only use knowledge distillation labels'),
    ('--lr_retrain', float, 0.01, 'Learning rate for optimization.'),
    ('--epoch_best', int, 500, 'total epoch for SVM'),
    ('--batch_size', int, 10, 'batch size for SVM'),
    ('--patience', int, 100, 'batch size for SVM'),
)

for _flag, _kind, _default, _help_text in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_kind, default=_default, help=_help_text)

# Parse the process command line.
args = parser.parse_args()

# Load the experiment configuration from the YAML file named by --yaml.
config = read_yaml(args.yaml)


# Checkpoint callback keyed on the logged 'val_loss' metric (lower is better).
# It writes into the directory named by config["CHECKPATH"], is evaluated
# every 10 epochs, and does not keep a separate "last" checkpoint.
# Within this script it is only attached to the trainer of the 'SVM' mode.
_checkpoint_options = dict(
    dirpath=config["CHECKPATH"],
    filename='{epoch}-{val_loss:.2f}-WRN_CIFAR100',
    monitor='val_loss',
    mode='min',
    save_last=False,
    every_n_epochs=10,
)
checkpoint_callback = ModelCheckpoint(**_checkpoint_options)

# Lightning datamodule built from the YAML configuration.
datamodule = get_datamodule(config=config)

# Raw torchvision dataset, looked up by the class name given in the config.
# In this script it is only used to read class metadata below.
dataset_class = getattr(datasets, config['DATASET'])
dataset_instance = dataset_class(root=config['DATAPATH'], download=True)

# ImageNet label list. Parsed exactly once; the previous revision opened and
# parsed this same file up to three times and re-imported `json` twice.
with open('imagenet_labels.json') as json_file:
    imagenet_labels = json.load(json_file)

# Dataset-specific class metadata.
if config['DATASET'] == 'SVHN':
    # Digit labels are listed by hand for this dataset.
    classes = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    config['NUM_CLASSES'] = len(classes)
    config['CLASS_NAME'] = classes
elif config['DATASET'] == 'CelebA':
    # All attribute names except the last entry.
    config['CLASS_NAME'] = dataset_instance.attr_names[:-1]
    config['NUM_CLASSES'] = len(dataset_instance.attr_names) - 1
elif config['DATASET'] == 'ImageNet':
    config['CLASS_NAME'] = imagenet_labels
    config['NUM_CLASSES'] = len(imagenet_labels)
else:
    config['NUM_CLASSES'] = len(dataset_instance.classes)
    config['CLASS_NAME'] = dataset_instance.classes

# NOTE(review): the values below are applied unconditionally, exactly as in the
# previous revision, so the dataset-specific metadata computed above never
# survives: every run ends up with the ImageNet label set and 3x224x224 shapes.
# The earlier per-dataset ORIGINAL_SHAPE stores ((3, 32, 32) by default,
# (3, 224, 224) for CelebA/ImageNet) were always overwritten and have been
# collapsed into this single assignment. Confirm whether the override is
# intentional or leftover experiment code before removing either side.
config['CLASS_NAME'] = imagenet_labels
config['NUM_CLASSES'] = len(imagenet_labels)
config['ORIGINAL_SHAPE'] = (3, 224, 224)
config['SHAPE'] = (3, 224, 224)

# Clear the module-level run handle; a run is started later via wandb.init().
# NOTE(review): presumably guards against a stale run object — confirm.
wandb.run = None

# Start the Weights & Biases run; `wandb.run` is read right below and again in
# the RECONSTRUCT branch, so this call must stay ahead of both.
wandb.init(project='1111')

# Logger handed to every Trainer created further down.
# NOTE(review): the project name here ('11112') differs from the one passed to
# wandb.init above ('1111') — confirm which project the run should land in,
# and that `log_hyperparams` is an option the logger actually honours.
wandb_logger = WandbLogger(project='11112', log_hyperparams='all')
# Name of the active run; not referenced again in the visible part of the script.
checkpoint_name = wandb.run.name



# Optional 'DATA_MEAN' switch from the config (False when the key is absent);
# not referenced again in the visible part of the script.
is_data_mean = config.get('DATA_MEAN', False)

# Entry point: pick the experiment to run from config['MODE'].
#   'RECONSTRUCT' -> fit Litreconstruction on the datamodule
#   'SVM'         -> fit and test LitResnet_SVM (no datamodule is passed)
#   'BASE'        -> fit and/or test LitResnet on the datamodule
# Any other value is a configuration error and terminates with a non-zero
# exit status.
if config['MODE'] == 'RECONSTRUCT':

    # Checkpoint named after the current wandb run, stored under DSV_dir.
    # NOTE(review): no `monitor` is configured, so save_top_k=1 cannot rank
    # checkpoints by a metric; confirm against the ModelCheckpoint docs whether
    # this keeps the most recent checkpoint rather than a "best" one.
    DSV_checkpoint = ModelCheckpoint(
        dirpath=DSV_dir,
        filename=wandb.run.name,
        save_top_k=1,
        save_last=False,
    )

    model = Litreconstruction(args=args, config=config)

    trainer = Trainer(
        max_epochs=args.UPSCALE_cycle,
        accelerator="auto",
        # Single device when CUDA is available; otherwise leave it unset.
        devices=1 if torch.cuda.is_available() else None,
        logger=wandb_logger,
        callbacks=[
            # Stop once 'my_acc' has not improved (higher is better) for 200 checks.
            EarlyStopping(monitor='my_acc', patience=200, verbose=False, mode='max'),
            DSV_checkpoint,
        ],
        # Element-wise gradient clipping with a very small bound.
        gradient_clip_val=0.00001,
        gradient_clip_algorithm="value",
    )

    trainer.fit(model, datamodule)

elif config['MODE'] == 'SVM':
    trainer = Trainer(
        max_epochs=5000,
        accelerator="auto",
        # Single device when CUDA is available; otherwise leave it unset.
        devices=1 if torch.cuda.is_available() else None,
        logger=wandb_logger,
        callbacks=[
            LearningRateMonitor(logging_interval="step"),
            TQDMProgressBar(refresh_rate=1),
            checkpoint_callback,
        ],
        log_every_n_steps=10,
    )
    model = LitResnet_SVM(args, config)
    # No datamodule is passed here.
    # NOTE(review): presumably the model supplies its own dataloaders — confirm.
    trainer.fit(model)
    trainer.test(model)

elif config['MODE'] == 'BASE':
    trainer = Trainer(
        max_epochs=200,
        accelerator="auto",
        # Four devices when CUDA is available; otherwise leave it unset.
        devices=4 if torch.cuda.is_available() else None,
        logger=wandb_logger,
        callbacks=[
            LearningRateMonitor(logging_interval="step"),
            TQDMProgressBar(refresh_rate=10),
        ],
    )
    model = LitResnet(config=config)
    # `== False` is kept on purpose: it also matches a YAML value of 0 and does
    # not match None, which `not ...` / `is False` would treat differently.
    if config.get('pretrained', False) == False:  # noqa: E712
        trainer.fit(model, datamodule)
    else:
        # 'pretrained' is set: skip fitting and only evaluate.
        print("test")
    # Evaluation runs in both cases.
    trainer.test(model, datamodule)

else:
    print("Please check the mode in config.yaml")
    # An unrecognised mode is an error, so report failure to the caller.
    # The previous `exit(0)` signalled success and relied on the interactive
    # `exit` helper injected by the `site` module.
    raise SystemExit(1)